`include    "aes_param.vh"

// aes_col_mixer: column-mixing stage with a one-cycle registered output.
//
// The 16 input bytes are grouped into four 32-bit words:
//   word r = {matrix[4r+3], matrix[4r+2], matrix[4r+1], matrix[4r]}
// Output byte (4*c + r) is func_matrix_one_compute(word r, COL_MIXER_FACTOR__c),
// for c, r in 0..3. func_matrix_one_compute and COL_MIXER_FACTOR__0..3 are not
// defined in this module; presumably they come from aes_param.vh (TODO confirm).
//
// Ports:
//   clk      - clock; matrix_m updates on its rising edge.
//   rst_n    - asynchronous, active-low reset; clears matrix_m to all zeros.
//   mode     - 0: encode, 1: decode.
//              NOTE(review): mode is not referenced anywhere in this module,
//              so both modes currently use the same factors. Confirm whether
//              the decode (inverse) factors are meant to be selected here.
//   matrix   - 16-byte input.
//   matrix_m - 16-byte mixed result, registered (one clock of latency).
module aes_col_mixer(

    input                       clk,
    input                       rst_n,

    input                       mode,   // 0:encode  1:decode (currently unused)
    input        [15:0][7:0]    matrix,

    // Declared once here as `logic`: an ANSI-style port must not be
    // redeclared (e.g. as `reg`) inside the module body.
    output logic [15:0][7:0]    matrix_m
);


// Four 32-bit words; word r holds input bytes 4r+3 .. 4r (MSB first).
wire    [3:0][31:0]     matrix_row;
// Combinational mixing result, captured into matrix_m on each clock edge.
wire    [15:0][7:0]     matrix_map;

assign matrix_row[0] = {matrix[3],  matrix[2],  matrix[1],  matrix[0]};
assign matrix_row[1] = {matrix[7],  matrix[6],  matrix[5],  matrix[4]};
assign matrix_row[2] = {matrix[11], matrix[10], matrix[9],  matrix[8]};
assign matrix_row[3] = {matrix[15], matrix[14], matrix[13], matrix[12]};


// Output bytes 0..3 use factor 0, bytes 4..7 use factor 1, and so on;
// within each group of four, byte (4*c + r) is computed from word r.
assign matrix_map[0] = func_matrix_one_compute(matrix_row[0],COL_MIXER_FACTOR__0);
assign matrix_map[1] = func_matrix_one_compute(matrix_row[1],COL_MIXER_FACTOR__0);
assign matrix_map[2] = func_matrix_one_compute(matrix_row[2],COL_MIXER_FACTOR__0);
assign matrix_map[3] = func_matrix_one_compute(matrix_row[3],COL_MIXER_FACTOR__0);

assign matrix_map[4] = func_matrix_one_compute(matrix_row[0],COL_MIXER_FACTOR__1);
assign matrix_map[5] = func_matrix_one_compute(matrix_row[1],COL_MIXER_FACTOR__1);
assign matrix_map[6] = func_matrix_one_compute(matrix_row[2],COL_MIXER_FACTOR__1);
assign matrix_map[7] = func_matrix_one_compute(matrix_row[3],COL_MIXER_FACTOR__1);

assign matrix_map[8] = func_matrix_one_compute(matrix_row[0],COL_MIXER_FACTOR__2);
assign matrix_map[9] = func_matrix_one_compute(matrix_row[1],COL_MIXER_FACTOR__2);
assign matrix_map[10]= func_matrix_one_compute(matrix_row[2],COL_MIXER_FACTOR__2);
assign matrix_map[11]= func_matrix_one_compute(matrix_row[3],COL_MIXER_FACTOR__2);

assign matrix_map[12]= func_matrix_one_compute(matrix_row[0],COL_MIXER_FACTOR__3);
assign matrix_map[13]= func_matrix_one_compute(matrix_row[1],COL_MIXER_FACTOR__3);
assign matrix_map[14]= func_matrix_one_compute(matrix_row[2],COL_MIXER_FACTOR__3);
assign matrix_map[15]= func_matrix_one_compute(matrix_row[3],COL_MIXER_FACTOR__3);


// Output register: asynchronous active-low reset to zero, otherwise load the
// full 16-byte mixing result every rising clock edge.
always_ff @(posedge clk or negedge rst_n) begin
    if (!rst_n)
        matrix_m <= '0;
    else
        matrix_m <= matrix_map;
end


endmodule
